"""
本文件是最终的预测文件，根据训练好的模型来预测
"""

import torch
import os
import datetime

from data_process import test_data_loader
from my_model import MyRnn
from config import *


# 1. Prepare the data: batch_size=1 so that each forward pass yields a single
#    scalar prediction, written as one CSV row below.
test_loader = test_data_loader(batch_size=1)


# 2. Prepare the model for prediction.
net_name = "xxxxxx.pth"  # TODO: set to the checkpoint file to evaluate

net = MyRnn(100)
# The network and inputs in this script are never moved to a GPU, so map the
# checkpoint tensors to the CPU; otherwise a checkpoint saved from a CUDA
# device fails to load on a CPU-only machine. load_state_dict copies the
# values into the module's existing parameters either way.
net.load_state_dict(torch.load(
    os.path.join(MODEL_SAVE_PATH, net_name),
    map_location="cpu",
))
# Put the model in inference mode so train-only behaviour (e.g. dropout or
# batch-norm statistics updates, if MyRnn contains such layers) is disabled.
net.eval()


# 3. Run prediction and write one "id,tested_positive" row per test sample.
now_time = datetime.datetime.now().strftime("%m-%d-%H-%M")
# NOTE(review): tqdm is not imported in this file; presumably it is re-exported
# by `from config import *` — confirm, or import it explicitly.
test_pbar = tqdm(test_loader, position=0, leave=True)
test_pbar.set_description("Testing...")

with open(os.path.join(RESULT_SAVE_PATH, f"xxxxx-{now_time}.csv"), "w") as f:
    f.write("id,tested_positive\n")
    # Pure inference: skip building the autograd graph for every sample.
    with torch.no_grad():
        for sample_id, x in enumerate(test_pbar):
            y_hat = net(x)
            # .item() assumes the model emits exactly one value per sample.
            f.write(f"{sample_id},{y_hat.item()}\n")
